﻿// Accord Statistics Library
// The Accord.NET Framework
// http://accord-framework.net
//
// Copyright © César Souza, 2009-2017
// cesarsouza at gmail.com
//
//    This library is free software; you can redistribute it and/or
//    modify it under the terms of the GNU Lesser General Public
//    License as published by the Free Software Foundation; either
//    version 2.1 of the License, or (at your option) any later version.
//
//    This library is distributed in the hope that it will be useful,
//    but WITHOUT ANY WARRANTY; without even the implied warranty of
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
//    Lesser General Public License for more details.
//
//    You should have received a copy of the GNU Lesser General Public
//    License along with this library; if not, write to the Free Software
//    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
//

namespace Accord.Statistics.Filters
{
    using System;
    using System.Data;
    using System.Runtime.Serialization;
    using Accord.Math;
    using Accord.MachineLearning;
    using System.Collections.Generic;
    using Accord.Compat;

    /// <summary>
    ///   Enumerates the ways in which a replacement value for
    ///   missing fields can be chosen by the imputation filter.
    /// </summary>
    ///
    /// <seealso cref="Imputation"/>
    /// <seealso cref="Imputation{T}"/>
    ///
    public enum ImputationStrategy
    {
        /// <summary>
        ///   Missing fields are replaced by a fixed, user-supplied value.
        /// </summary>
        ///
        FixedValue = 0,

        /// <summary>
        ///   Missing fields are replaced by the mean of the column.
        /// </summary>
        ///
        Mean = 1,

        /// <summary>
        ///   Missing fields are replaced by the mode of the column.
        /// </summary>
        ///
        Mode = 2,

        /// <summary>
        ///   Missing fields are replaced by the median of the column.
        /// </summary>
        ///
        Median = 3
    }

    /// <summary>
    ///   Imputation filter for filling missing values. This is the
    ///   non-generic flavor of <see cref="Imputation{T}"/>, operating
    ///   on values typed as <see cref="System.Object"/>.
    /// </summary>
    ///
    /// <example>
    ///   <code source="Unit Tests\Accord.Tests.Statistics\Filters\ImputationFilterTest.cs" region="doc_learn" />
    /// </example>
    ///
    [Serializable]
    public class Imputation : Imputation<object>
    {
        /// <summary>
        ///   Creates a new Imputation filter using the
        ///   parameterless constructor of the base class.
        /// </summary>
        ///
        public Imputation() : base()
        {
        }

        /// <summary>
        ///   Creates a new Imputation filter for the given named
        ///   columns, forwarding both arguments to the base class.
        /// </summary>
        ///
        /// <param name="names">The names of the columns.</param>
        /// <param name="data">The data handed to the base constructor.</param>
        ///
        public Imputation(string[] names, object[][] data) : base(names, data)
        {
        }

        /// <summary>
        ///   Creates a new Imputation filter, forwarding
        ///   the given data to the base class.
        /// </summary>
        ///
        /// <param name="data">The data handed to the base constructor.</param>
        ///
        public Imputation(object[][] data) : base(data)
        {
        }

        /// <summary>
        ///   Creates a new Imputation filter for the given named columns.
        /// </summary>
        ///
        /// <param name="columns">The names of the columns.</param>
        ///
        public Imputation(params string[] columns) : base(columns)
        {
        }
    }

    /// <summary>
    ///   Imputation filter for filling missing values.
    /// </summary>
    /// 
    /// <example>
    ///   <code source="Unit Tests\Accord.Tests.Statistics\Filters\ImputationFilterTest.cs" region="doc_learn" />
    /// </example>
    /// 
    [Serializable]
    public class Imputation<T> : BaseFilter<Imputation<T>.Options, Imputation<T>>,
        IAutoConfigurableFilter, ITransform<T[], T[]>,
        IUnsupervisedLearning<Imputation<T>, T[], T[]>
    {
        /// <summary>
        ///   Gets the number of outputs generated by the model, which is
        ///   the number of column options currently held by this filter.
        /// </summary>
        ///
        /// <value>The number of outputs.</value>
        ///
        public int NumberOfOutputs
        {
            get
            {
                return this.Columns.Count;
            }
        }

        /// <summary>
        ///   Creates a new Imputation filter, relying solely on
        ///   the parameterless constructor of the base filter.
        /// </summary>
        ///
        public Imputation() : base()
        {
        }

        /// <summary>
        ///   Creates a new Imputation filter and immediately
        ///   learns its column options from the given data.
        /// </summary>
        ///
        /// <param name="data">The data to learn from, one sample per row.</param>
        ///
        public Imputation(T[][] data)
        {
            // No weights are given, so the learning call receives its default (null).
            Learn(data);
        }

        /// <summary>
        ///   Creates a new Imputation filter with one column option per
        ///   given name and then learns those options from the given data.
        /// </summary>
        ///
        /// <param name="columnNames">The names of the columns, in column order.</param>
        /// <param name="data">The data to learn from, one sample per row.</param>
        ///
        public Imputation(string[] columnNames, T[][] data)
        {
            // Register the named columns first so that learning reuses them.
            for (int k = 0; k < columnNames.Length; k++)
            {
                Columns.Add(new Options(columnNames[k]));
            }

            Learn(data);
        }

        /// <summary>
        ///   Creates a new Imputation filter with one column
        ///   option per given name. Nothing is learned yet.
        /// </summary>
        ///
        /// <param name="columnNames">The names of the columns, in column order.</param>
        ///
        public Imputation(params string[] columnNames)
        {
            for (int k = 0; k < columnNames.Length; k++)
            {
                Columns.Add(new Options(columnNames[k]));
            }
        }

        /// <summary>
        ///   Applies the transformation to a single input vector by wrapping it
        ///   in a one-row batch and delegating to the batch transformation.
        /// </summary>
        /// <param name="input">The input data to which the transformation should be applied.</param>
        /// <returns>The output generated by applying this transformation to the given input.</returns>
        public T[] Transform(T[] input)
        {
            T[][] batch = new T[][] { input };
            T[][] transformed = Transform(batch);
            return transformed[0];
        }

        /// <summary>
        ///   Applies the transformation to a set of input vectors, storing the
        ///   outcome in a newly created jagged array obtained from
        ///   <c>Jagged.CreateAs</c>, so that the input is left untouched.
        /// </summary>
        /// <param name="input">The input data to which
        ///   the transformation should be applied.</param>
        /// <returns>The output generated by applying this
        ///   transformation to the given input.</returns>
        public T[][] Transform(T[][] input)
        {
            var destination = Jagged.CreateAs(input);
            return Transform(input, destination);
        }

        /// <summary>
        ///   Applies the transformation to a set of input vectors, writing
        ///   every element (either copied or replaced) into the given result.
        /// </summary>
        /// <remarks>
        ///   The element at position <c>j</c> of every row is handled with the
        ///   options stored at index <c>j</c> of the column collection. An element
        ///   is replaced only when the column has a non-null replacement value and
        ///   the element is recognized as missing; otherwise it is copied as-is.
        /// </remarks>
        /// <param name="input">The input data to which
        ///   the transformation should be applied.</param>
        /// <param name="result">The location to where to store the
        ///   result of this transformation.</param>
        /// <returns>The <paramref name="result"/> array, after being filled.</returns>
        public T[][] Transform(T[][] input, T[][] result)
        {
            for (int row = 0; row < input.Length; row++)
            {
                T[] source = input[row];

                for (int col = 0; col < source.Length; col++)
                {
                    Options options = Columns[col];
                    T current = source[col];

                    // For value types the replacement is never null, so only the missing check decides.
                    bool replace = options.ReplaceWith != null && options.IsMissingValue(current);

                    result[row][col] = replace ? options.ReplaceWith : current;
                }
            }

            return result;
        }

#if !NETSTANDARD1_4
        /// <summary>
        ///   Processes the current filter: works on a copy of the given table
        ///   and, for every configured column that has a non-null replacement
        ///   value, overwrites the cells recognized as missing.
        /// </summary>
        ///
        /// <param name="data">The table to process. It is not modified.</param>
        ///
        /// <returns>The processed copy of the table.</returns>
        ///
        protected override DataTable ProcessFilter(DataTable data)
        {
            // Work on a copy so the caller's table stays intact
            DataTable output = data.Copy();

            foreach (DataRow record in output.Rows)
            {
                foreach (Options column in Columns)
                {
                    if (column.ReplaceWith == null)
                        continue;

                    if (!column.IsMissingValue(record[column.ColumnName]))
                        continue;

                    record[column.ColumnName] = column.ReplaceWith;
                }
            }

            return output;
        }

        /// <summary>
        ///   Auto detects the filter options by analyzing a given <see cref="System.Data.DataTable"/>.
        ///   This simply forwards to the table-based learning method, without weights.
        /// </summary>
        ///
        /// <param name="data">The table to analyze.</param>
        ///
        public void Detect(DataTable data)
        {
            this.Learn(data, null);
        }

        /// <summary>
        ///   Learns the imputation options from a data table. Options are created
        ///   for every table column that does not have them yet, and then each
        ///   table column's options learn from the table.
        /// </summary>
        ///
        /// <param name="x">The table to learn from.</param>
        /// <param name="weights">Sample weights are not supported and must be null.</param>
        ///
        /// <returns>This filter instance, after learning.</returns>
        ///
        /// <exception cref="ArgumentException">A non-null <paramref name="weights"/> was given.</exception>
        ///
        public Imputation<T> Learn(DataTable x, double[] weights = null)
        {
            if (weights != null)
            {
                throw new ArgumentException(Accord.Properties.Resources.NotSupportedWeights, "weights");
            }

            DataColumnCollection tableColumns = x.Columns;

            // First pass: make sure every table column has an options entry
            foreach (DataColumn column in tableColumns)
            {
                string name = column.ColumnName;

                if (this.Columns.Contains(name))
                    continue;

                this.Columns.Add(new Options(name));
            }

            // Second pass: let each entry learn from its own column
            foreach (DataColumn column in tableColumns)
            {
                Options options = this.Columns[column.ColumnName];
                options.Learn(x, weights);
            }

            return this;
        }
#endif

        /// <summary>
        ///   Learns the imputation options from a jagged array. Missing column
        ///   options are created on demand (named after their index), and then
        ///   each column's options learn from the corresponding data column.
        /// </summary>
        ///
        /// <param name="x">The data to learn from, one sample per row.</param>
        /// <param name="weights">Sample weights are not supported and must be null.</param>
        ///
        /// <returns>This filter instance, after learning.</returns>
        ///
        /// <exception cref="ArgumentException">A non-null <paramref name="weights"/> was given.</exception>
        /// <exception cref="Exception">There are more predefined columns than columns in the data.</exception>
        ///
        public Imputation<T> Learn(T[][] x, double[] weights = null)
        {
            if (weights != null)
            {
                throw new ArgumentException(Accord.Properties.Resources.NotSupportedWeights, "weights");
            }

            int dataColumns = x.Columns();

            // Create options for the data columns that were not predefined
            int next = this.Columns.Count;
            while (next < dataColumns)
            {
                this.Columns.Add(new Options(next.ToString()));
                next++;
            }

            // Any mismatch left at this point means too many predefined columns
            if (this.Columns.Count != dataColumns)
            {
                throw new Exception("There are more predefined columns than columns in the data.");
            }

            for (int j = 0; j < this.Columns.Count; j++)
            {
                T[] column = x.GetColumn(j);
                this.Columns[j].Learn(column, weights);
            }

            return this;
        }


        /// <summary>
        ///   Explicit interface view of <see cref="NumberOfOutputs"/>.
        ///   Reading returns the same value; writing is rejected.
        /// </summary>
        ///
        /// <exception cref="InvalidOperationException">The property is being set.</exception>
        ///
        int ITransform.NumberOfOutputs
        {
            get
            {
                return this.NumberOfOutputs;
            }
            set
            {
                throw new InvalidOperationException("This property is read-only.");
            }
        }

        /// <summary>
        ///   Options for the imputation filter.
        /// </summary>
        /// 
        [Serializable]
        public class Options : ColumnOptionsBase<Imputation<T>>, IAutoConfigurableColumn
        {

            // Backing field for Strategy. It is marked as optional for serialization,
            // so a serialized stream that lacks this field can still be read. When
            // never assigned it holds the enumeration's default member (FixedValue).
            [OptionalField]
            private ImputationStrategy strategy;

            /// <summary>
            ///   Gets or sets the imputation strategy to use with this column.
            ///   The strategy decides how the learning methods obtain the
            ///   replacement value: with a fixed value strategy the current
            ///   replacement is kept as it is, otherwise it is computed from
            ///   the non-missing values of the column.
            /// </summary>
            public ImputationStrategy Strategy
            {
                get { return strategy; }
                set { strategy = value; }
            }

            /// <summary>
            ///   Missing value indicator. Values equal to this indicator are
            ///   treated as missing, both when learning and when checking
            ///   values with <see cref="IsMissingValue"/>.
            /// </summary>
            /// 
            /// <remarks>
            ///   The named constructor initializes it to <c>NaN</c> when the element
            ///   type is <see cref="System.Double"/> or <see cref="System.Single"/>, to
            ///   <c>-1</c> when it is <see cref="System.Int32"/>, and to the default
            ///   value of the element type otherwise.
            /// </remarks>
            /// 
            public T MissingValue { get; set; }

            /// <summary>
            ///   Value to replace missing values with.
            /// </summary>
            /// 
            /// <remarks>
            ///   The named constructor initializes it to the default value of the
            ///   element type. The learning methods overwrite it for every strategy
            ///   except <see cref="ImputationStrategy.FixedValue"/>, in which case
            ///   the value assigned by the user is kept.
            /// </remarks>
            /// 
            public T ReplaceWith { get; set; }

            /// <summary>
            ///   Constructs a new column option for the Imputation filter,
            ///   choosing a default missing value indicator from the element
            ///   type: <c>NaN</c> for floating point types, <c>-1</c> for
            ///   integers and the type's default value for anything else.
            /// </summary>
            ///
            /// <param name="name">The name of the column these options refer to.</param>
            ///
            public Options(String name)
                : base(name)
            {
                Type elementType = typeof(T);

                if (elementType == typeof(int))
                {
                    this.MissingValue = (-1).To<T>();
                }
                else if (elementType == typeof(Single))
                {
                    this.MissingValue = (Single.NaN).To<T>();
                }
                else if (elementType == typeof(Double))
                {
                    this.MissingValue = (Double.NaN).To<T>();
                }
                else
                {
                    this.MissingValue = default(T);
                }

                // The replacement starts out as the default value of the element type
                this.ReplaceWith = default(T);
            }

            /// <summary>
            ///   Constructs a new column option for the Imputation
            ///   filter using the placeholder name "New column".
            /// </summary>
            ///
            public Options() : this("New column")
            {
            }

#if !NETSTANDARD1_4
            /// <summary>
            ///   Auto detects the column options by analyzing
            ///   a given <see cref="System.Data.DataColumn"/>. The column is
            ///   converted to an array of the element type, from which the
            ///   options are then learned.
            /// </summary>
            ///
            /// <param name="column">The column to analyze.</param>
            ///
            public void Detect(DataColumn column)
            {
                T[] values = column.To<T[]>();
                this.Learn(values);
            }

            /// <summary>
            ///   Learns the options of this column from a data table, using
            ///   the table column whose name matches the name of this option.
            /// </summary>
            ///
            /// <param name="x">The table that contains the column to learn from.</param>
            /// <param name="weights">Sample weights are not supported and must be null.</param>
            ///
            /// <returns>This options instance, after learning.</returns>
            ///
            /// <exception cref="ArgumentException">A non-null <paramref name="weights"/> was given.</exception>
            ///
            public Options Learn(DataTable x, double[] weights = null)
            {
                if (weights != null)
                {
                    throw new ArgumentException(Accord.Properties.Resources.NotSupportedWeights, "weights");
                }

                DataColumn own = x.Columns[this.ColumnName];
                return Learn(own);
            }

            /// <summary>
            ///   Learns the replacement value of this column from a data column,
            ///   according to the configured strategy. Missing cells are left
            ///   out of the computation.
            /// </summary>
            ///
            /// <param name="x">The data column to learn from.</param>
            /// <param name="weights">Not used by this method.</param>
            ///
            /// <returns>This options instance, after learning.</returns>
            ///
            /// <exception cref="NotSupportedException">The mean or median was requested
            ///   for a column whose data type is neither double nor int.</exception>
            ///
            private Options Learn(DataColumn x, double[] weights = null)
            {
                Type columnType = x.DataType;
                SetDefaultMissingValue(columnType);

                switch (strategy)
                {
                    case ImputationStrategy.FixedValue:
                        // Nothing to learn: the replacement is whatever the user has set
                        return this;

                    case ImputationStrategy.Mode:
                        // The mode only needs equality, so it works on boxed values of any type
                        ReplaceWith = filter<object>(x).Mode().To<T>();
                        return this;
                }

                // Mean and median are only available for numeric columns; the
                // shared helper applies whichever of the two is configured.
                if (columnType == typeof(double))
                {
                    compute(filter<double>(x));
                }
                else if (columnType == typeof(int))
                {
                    compute(filter<int>(x));
                }
                else
                {
                    throw new NotSupportedException("The imputation strategy {0} is not supported for values of type {1}".Format(strategy, typeof(T)));
                }

                return this;
            }
#endif

            /// <summary>
            ///   Learns the replacement value of this column from an array of
            ///   values, according to the configured strategy. Missing values
            ///   are left out of the computation.
            /// </summary>
            ///
            /// <param name="x">The values of the column.</param>
            /// <param name="weights">Sample weights are not supported and must be null.</param>
            ///
            /// <returns>This options instance, after learning.</returns>
            ///
            /// <exception cref="ArgumentException">A non-null <paramref name="weights"/> was given.</exception>
            /// <exception cref="ArgumentNullException"><paramref name="x"/> is null.</exception>
            /// <exception cref="NotSupportedException">The mean or median was requested
            ///   for values that are neither doubles nor integers.</exception>
            ///
            public Options Learn(T[] x, double[] weights = null)
            {
                if (weights != null)
                    throw new ArgumentException(Accord.Properties.Resources.NotSupportedWeights, "weights");

                if (x == null)
                    throw new ArgumentNullException("x");

                // Determine the runtime type of the values from the first non-null
                // element. The first element cannot be trusted: it may be null (the
                // default missing value marker for reference types), and the array
                // may even be empty. Fall back to the declared element type then.
                Type type = typeof(T);
                for (int i = 0; i < x.Length; i++)
                {
                    if (x[i] != null)
                    {
                        type = x[i].GetType();
                        break;
                    }
                }

                SetDefaultMissingValue(type);

                // With a fixed value the replacement is given by the user, not learned
                if (strategy == ImputationStrategy.FixedValue)
                    return this;

                // Only non-missing values take part in the statistics
                x = filter(x);

                if (strategy == ImputationStrategy.Mode)
                {
                    ReplaceWith = x.Mode();
                    return this;
                }

                // Try the array as it is first (double[] or int[]); otherwise
                // convert it according to the runtime type of its elements.
                if (!compute(x))
                {
                    if (type == typeof(double))
                        compute(x.To<double[]>());
                    else if (type == typeof(int))
                        compute(x.To<int[]>());
                    else
                        throw new NotSupportedException("The imputation strategy {0} is not supported for values of type {1}".Format(strategy, typeof(T)));
                }

                return this;
            }

            /// <summary>
            ///   Picks a missing value indicator that suits the runtime type of the
            ///   data, but only when that type differs from the element type and
            ///   no indicator has been set yet (that is, it is still null).
            /// </summary>
            ///
            /// <param name="type">The runtime type of the values in the column.</param>
            ///
            private void SetDefaultMissingValue(Type type)
            {
                // Nothing to do if the types agree or an indicator is already in place
                if (typeof(T) == type || this.MissingValue != null)
                    return;

                if (type == typeof(int))
                {
                    this.MissingValue = (-1).To<T>();
                    return;
                }

                if (type == typeof(double))
                {
                    this.MissingValue = (Double.NaN).To<T>();
                    return;
                }

                if (type == typeof(float))
                {
                    this.MissingValue = (Single.NaN).To<T>();
                    return;
                }

#if NETSTANDARD1_4
                this.MissingValue = default(T);
#else
                this.MissingValue = type.GetDefaultValue().To<T>();
#endif
            }

            /// <summary>
            ///   Determines whether the given object denotes a missing value. Null
            ///   references, <c>DBNull</c> (where available) and anything equal to
            ///   <see cref="MissingValue"/> are all considered to be missing.
            /// </summary>
            ///
            /// <param name="value">The value to check.</param>
            ///
            /// <returns>True if the value is missing; false otherwise.</returns>
            ///
            public bool IsMissingValue(object value)
            {
                if (value == null)
                    return true;

#if !NETSTANDARD1_4
                if (value is DBNull)
                    return true;
#endif

                return Object.Equals(this.MissingValue, value);
            }

            /// <summary>
            ///   Updates the replacement value from a numeric column according to
            ///   the current strategy. Only arrays of doubles or integers are
            ///   handled, and only the mean and median strategies have an effect.
            /// </summary>
            ///
            /// <param name="column">The column values, expected to be a double[] or an int[].</param>
            ///
            /// <returns>True if the column was a double[] or an int[]; false otherwise
            ///   (in which case nothing has been changed).</returns>
            ///
            private bool compute(object column)
            {
                if (column is double[])
                {
                    double[] reals = (double[])column;

                    if (Strategy == ImputationStrategy.Mean)
                        ReplaceWith = reals.Mean().To<T>();
                    else if (Strategy == ImputationStrategy.Median)
                        ReplaceWith = reals.Median().To<T>();

                    return true;
                }

                if (column is int[])
                {
                    int[] integers = (int[])column;

                    if (Strategy == ImputationStrategy.Mean)
                        ReplaceWith = integers.Mean().To<T>();
                    else if (Strategy == ImputationStrategy.Median)
                        ReplaceWith = integers.Median().To<T>();

                    return true;
                }

                // Not a numeric array this method knows how to handle
                return false;
            }

#if !NETSTANDARD1_4
            /// <summary>
            ///   Collects the cells of a data column that are neither <c>DBNull</c> nor
            ///   equal to the missing value indicator, converting each of them to
            ///   <typeparamref name="TValue"/>.
            /// </summary>
            ///
            /// <param name="column">The column whose table rows are scanned.</param>
            ///
            /// <returns>The converted, non-missing cells in row order.</returns>
            ///
            private TValue[] filter<TValue>(DataColumn column)
            {
                DataRowCollection rows = column.Table.Rows;
                var kept = new List<TValue>();

                foreach (DataRow record in rows)
                {
                    object cell = record[column];

                    bool missing = Object.Equals(cell, this.MissingValue) || cell is DBNull;

                    if (!missing)
                    {
                        kept.Add((TValue)System.Convert.ChangeType(cell, typeof(TValue)));
                    }
                }

                return kept.ToArray();
            }
#endif

            /// <summary>
            ///   Collects the values of a column that are not missing, that is, the
            ///   ones that are not null, not <c>DBNull</c> and not equal to the
            ///   missing value indicator.
            /// </summary>
            ///
            /// <param name="column">The values of the column.</param>
            ///
            /// <returns>The non-missing values, in their original order.</returns>
            ///
            private TValue[] filter<TValue>(TValue[] column)
            {
                var m = new List<TValue>(column.Length);

                foreach (TValue value in column)
                {
                    // Null references are always regarded as missing when values are
                    // checked for replacement, so they must not take part in the learned
                    // statistics either. Skipping them also avoids the exception that
                    // converting a null to a value type (such as a Nullable) would raise.
                    if (value == null || value is DBNull || Object.Equals(value, this.MissingValue))
                        continue;

                    m.Add((TValue)System.Convert.ChangeType(value, typeof(TValue)));
                }

                return m.ToArray();
            }
        }
    }
}
